import hydra
import time
from omegaconf import DictConfig, OmegaConf

from envs.multiagent_env import MultiAgentEnv
from utils.rollout import rollout, process_learning_data
from utils.metrics import Metrics
from policies import construct_policy_map
from policies.utils.transition import Transition


@hydra.main(version_base=None, config_path="../configs", config_name="config")
def train(config):
    """Run the decentralized multi-agent training loop.

    Builds the environment and the policy map from ``config``. It then runs
    up to ``config.num_train_iterations`` iterations. Each iteration collects
    rollouts, calls ``learn()`` on every policy, and records the per-agent
    results reported by the rollout. The loop stops early when
    ``config.early_stopping`` is set and the metrics tracker says to stop.

    Args:
        config: Hydra/OmegaConf configuration, injected by ``@hydra.main``.
            Fields read directly here: ``exp_config``,
            ``num_train_iterations``, ``num_episodes_per_iter``, ``logdir``,
            ``early_stopping`` and ``early_stopping_iterations``. The last
            is read only when ``early_stopping`` is truthy.
            ``MultiAgentEnv`` may read further fields.
    """
    # perf_counter is monotonic, so the elapsed time reported at the end is
    # not affected by wall-clock adjustments (unlike time.time()).
    start_time = time.perf_counter()
    print(OmegaConf.to_yaml(config))

    # Create environment
    env = MultiAgentEnv(config)
    try:
        # Generate policy mappings
        policies = construct_policy_map(config.exp_config)
        # Per-agent history of the values reported by each iteration's rollout.
        episode_returns = {}
        metrics = Metrics(policies)

        # Start Training Loop
        for _ in range(config.num_train_iterations):
            # Collect Data and Process it for Learning.
            # NOTE(review): results is presumably {agent_id: episode return};
            # confirm against utils.rollout.rollout.
            results = rollout(
                env,
                policies,
                num_episodes=config.num_episodes_per_iter,
                metrics=metrics,
                logdir=config.logdir,
                data_processing_function=process_learning_data,
            )
            print('episode returns', results)

            # Decentralized Learning Phase: call learn() on every policy.
            for policy in policies.values():
                policy.learn()

            # Accumulate this iteration's per-agent results.
            for agent, reward in results.items():
                episode_returns.setdefault(agent, []).append(reward)

            # Check for early stopping. The short-circuit means
            # check_early_stopping is only called when early stopping is enabled.
            if config.early_stopping and metrics.check_early_stopping(
                config.early_stopping_iterations
            ):
                break

        print('episode returns', episode_returns)
        metrics.report()
    finally:
        # Release the environment even if policy construction, rollout,
        # learning or reporting raises.
        env.close()

    print("Training completed in: ", time.perf_counter() - start_time)

if __name__ == "__main__":
    # No arguments: the @hydra.main decorator on train supplies `config`.
    train()